#!/usr/bin/env python
import os

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langserve import add_routes

# Workaround for the OpenAPI schema endpoint returning HTTP 500: rebuild the
# pydantic request model so its type annotations are re-resolved before the
# schema is generated.
# NOTE(review): langserve appears to generate per-route request models (named
# after the route namespace, e.g. "openai") inside add_routes(). Confirm that
# `openaiInvokeRequest` is really importable from langserve.validation at this
# point -- before any add_routes() call -- with the installed langserve
# version; if it is not defined there, this import raises ImportError.
from langserve.validation import openaiInvokeRequest
openaiInvokeRequest.model_rebuild()

# The FastAPI application; every LangServe route below is mounted on it.
app = FastAPI(
    description="A simple api server using Langchain's Runnable interfaces",
    version="1.0",
    title="LangChain Server",
)

# Browser origins allowed to call the API, taken from the comma-separated
# CORS_ALLOW_ORIGINS environment variable (e.g. "http://localhost:3000,https://app.example").
# Unset or empty means "*" (any origin), matching the previous behavior.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
_cors_allow_all = "*" in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # The CORS spec forbids a wildcard origin on credentialed requests, and
    # Starlette's docs require explicit origins when allow_credentials is on.
    # Enabling credentials for "*" would hand every website credentialed
    # access, so credentials are only allowed for an explicit origin list.
    allow_credentials=not _cors_allow_all,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

# The SiliconFlow endpoint is consumed through ChatOpenAI, i.e. via its
# OpenAI-compatible API. Both routes share this configuration.
_SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1/"
_SILICONFLOW_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"


def _make_chat_model() -> ChatOpenAI:
    """Build a ChatOpenAI client for the SiliconFlow endpoint.

    The API key is read from the ``SILICONFLOW_API_KEY`` environment variable
    so real credentials never need to be committed to source. The placeholder
    ``"sk-xxx"`` is used only when the variable is unset.
    """
    return ChatOpenAI(
        openai_api_base=_SILICONFLOW_BASE_URL,
        openai_api_key=os.environ.get("SILICONFLOW_API_KEY", "sk-xxx"),
        model_name=_SILICONFLOW_MODEL,
    )


# Expose the bare chat model at /openai.
add_routes(app, _make_chat_model(), path="/openai")

# Expose a prompt -> model chain at /joke. `model` and `prompt` remain
# module-level names in case other code imports them.
model = _make_chat_model()
prompt = ChatPromptTemplate.from_template("tell me a joke about {topic}")
add_routes(app, prompt | model, path="/joke")

if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn on localhost:8100.
    from uvicorn import run as serve

    serve(app, port=8100, host="localhost")
